#!/usr/bin/env python3
# Copyright European Organization for Nuclear Research (CERN) since 2012
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Utility script to generate file headers.

The script should exit with a non-zero exit code if a header is wrong or fixed,
to let the user know about a possible change. A general file type, defining
which files should get headers, has to be a class inherited from the
`HeaderTemplate` abstract class. The class gets registered automatically and is
accessible through the `header_template_factory` function. The implementation
for each file gets selected by the `is_file_suitable` method provided by the
template. Other methods to get information about the currently used header and
the new header have to be provided too.
"""

import argparse
import logging
import os
import sys
from abc import ABC, abstractmethod
from collections.abc import Iterator, ValuesView
from enum import Enum, auto
from typing import Optional

try:
    from tqdm import tqdm
except ImportError:
    pass


class Arguments:
    """
    All arguments for the executable.

    Serves as the namespace for `argparse`; the attribute names correspond to
    the destinations of the arguments registered on the parser.
    """
    # Files to check. If empty, all suitable files known to git get checked.
    files: Optional[list[str]]
    # Set by `-d` / `--dry-run`: dry run mode, does not change anything.
    dry_run: bool
    # Set by `--disable-progress-bar`: disables the progress bar.
    disable_progress_bar: bool


class ResultStates(Enum):
    """ All states a file can be in after processing. """
    FileDoesNotExist = auto()
    FileTemplateNotFound = auto()
    HeaderIsCorrect = auto()
    HeaderIsIncorrect = auto()
    HeaderModified = auto()

    @staticmethod
    def get_exit_code(state) -> int:
        """
        Returns the exit code corresponding to the state.

        :param state: The state for which the exit code should be determined.
        :returns: 1 if the header is incorrect or got modified, 2 if the file
                  does not exist, 0 for every other state.
        :throws TypeError: If the state is not a ResultStates element.
        """
        # `isinstance` is used instead of `state in ResultStates`: before
        # Python 3.12 the containment check itself raises for values which are
        # not enum members, so the message below would never be reached.
        if not isinstance(state, ResultStates):
            raise TypeError(f"Provided element {state} is not a ResultState enum element.")

        if state in (ResultStates.HeaderIsIncorrect, ResultStates.HeaderModified):
            return 1
        if state is ResultStates.FileDoesNotExist:
            return 2
        return 0


def exit_code_from_result_states(states: ValuesView[ResultStates]) -> int:
    """
    Determines the exit code of the program from the states of all files.

    :param states: The ResultStates of all processed files.
    :returns: The highest exit code of all states, as this is the most
              "important" one.
    :throws Exception: If no state is provided.
    """
    if not len(states):
        raise Exception("States should not be empty!")

    exit_codes = (ResultStates.get_exit_code(current) for current in states)
    return max(exit_codes)


def print_result_states(states: dict[str, ResultStates]) -> None:
    """
    Logs the outcome for every file which needs attention, followed by a
    summary line.

    :param states: A dict with a file name as key and the corresponding ResultState.
    """
    # The reports are emitted grouped by state, in exactly this order.
    reports = (
        (ResultStates.HeaderModified, "Header modified in file %s"),
        (ResultStates.HeaderIsIncorrect, "Header is incorrect in file %s"),
        (ResultStates.FileTemplateNotFound, "File Template not found for file %s"),
        (ResultStates.FileDoesNotExist, "File %s not found"),
    )

    files_by_state = {}
    for wanted, message in reports:
        matching = [name for name, outcome in states.items() if outcome == wanted]
        files_by_state[wanted] = matching
        for name in matching:
            logging.info(message, name)

    modified_count = len(files_by_state[ResultStates.HeaderModified])
    incorrect_count = len(files_by_state[ResultStates.HeaderIsIncorrect])
    logging.info(
        "Processed %d files, modified %d, got %d files with problems",
        len(states),
        modified_count,
        modified_count + incorrect_count
    )


class HeaderTemplateNotFoundException(Exception):
    """ Exception if no suitable template is found for the current file. """


class HeaderTemplate(ABC):
    """
    Abstract base class of the header generators. Each subclass handles one
    type of files and gets registered automatically in `subtemplates`.
    """
    subtemplates: list[type] = []

    def __init_subclass__(cls, **kwargs):
        """
        Registers every subclass in the `subtemplates` list, so the available
        templates can be iterated to pick the matching one for a file.

        https://peps.python.org/pep-0487/#new-ways-of-using-classes
        """
        super().__init_subclass__(**kwargs)
        cls.subtemplates.append(cls)

    def __init__(self, file_path: str):
        """
        Constructor.

        :param file_path: Path to the file the template is created for.
        """
        self.file_path = file_path

    @staticmethod
    @abstractmethod
    def is_file_suitable(file_path: str) -> bool:
        """
        Decides whether this template is responsible for a file. Every file
        should only get accepted by one template.

        :param file_path: The path to the file.
        :returns: True if the template can be used for the file, False otherwise.
        """

    @abstractmethod
    def get_header(self) -> str:
        """
        Returns the header the file is expected to have.

        :returns: The header for the template.
        """

    @abstractmethod
    def current_header_number_of_lines(self) -> int:
        """
        Counts the lines of the header which is currently in the file.

        :returns: Number of lines of the current header in the file.
        """

    def is_header_correct(self) -> bool:
        """
        Checks whether the expected header is contained in the file.

        :returns: True if the file content contains the expected header, False
                  otherwise.
        """
        with open(self.file_path, "r") as source:
            current_text = source.read()

        expected_header = self.get_header()
        return expected_header in current_text


class BashHeaderTemplate(HeaderTemplate):
    """ Header generator for shell scripts (`.sh` files and files with a sh/bash shebag). """

    @staticmethod
    def _get_file_shebag(file_path: str) -> Optional[str]:
        """
        Returns the bash file shebag.

        :param file_path: The path to the file.
        :returns: The first line of the file if it starts with `#!/bin/sh` or
                  `#!/bin/bash`, None otherwise. The line includes its new line
                  character if the file has one.
        :throws UnicodeDecodeError: If the file content could not be decoded.
        :throws OSError: If the file could not be opened.
        """
        # The whole file gets decoded here, not only the first line:
        # `is_file_suitable` uses the UnicodeDecodeError to reject files which
        # can not be read as text.
        with open(file_path, "r") as file_obj:
            file_content = file_obj.readlines()

        if not file_content:
            return None

        first_line = file_content[0]
        # A shebag only counts at the very beginning of the first line
        if first_line.startswith(("#!/bin/sh", "#!/bin/bash")):
            return first_line
        return None

    @staticmethod
    def is_file_suitable(file_path: str) -> bool:
        """
        Accepts files which end with `.sh` or start with a sh/bash shebag.

        :param file_path: The path to the file.
        :returns: True if the template can be used for the file, False otherwise.
        """
        # The ignore rules of the python template are used for bash files too
        if PythonHeaderTemplate._should_file_be_ignored(file_path):
            return False

        try:
            is_file_using_shebag = BashHeaderTemplate._get_file_shebag(file_path) is not None
        except (UnicodeDecodeError, OSError):
            # The path can not be read as a text file (binary content, a
            # directory, a missing file, ...), this makes it automatically invalid
            return False

        return os.path.splitext(file_path)[1] == ".sh" or is_file_using_shebag

    def get_header(self) -> str:
        """
        Returns the expected header. If the file starts with a shebag line,
        this line is kept in front of the header.

        :returns: The header for the file.
        """
        with open(self.file_path, "r") as file_obj:
            first_line = file_obj.readline()

        # Every shebag gets preserved, not only the sh/bash ones:
        # `current_header_number_of_lines` counts the shebag line as part of
        # the header, so it would get lost if the header gets replaced.
        shebag_string = first_line if first_line.startswith("#!") else ""
        if shebag_string and not shebag_string.endswith("\n"):
            # The file consists only of the shebag without a trailing new line
            shebag_string += "\n"

        return shebag_string \
            + """# -*- coding: utf-8 -*-
# Copyright European Organization for Nuclear Research (CERN) since 2012
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

    def current_header_number_of_lines(self) -> int:
        """
        Counts the leading lines of the file which start with `#`. This
        includes a shebag line.

        :returns: Number of lines of the current header in the file.
        """
        ret = 0
        with open(self.file_path, "r") as file_obj:
            for line in file_obj:
                if not line.startswith("#"):
                    break
                ret += 1
        return ret


class PythonHeaderTemplate(HeaderTemplate):
    """ Header generator for python files (`.py`, `.py.mako` and files with a python shebag). """

    @staticmethod
    def _should_file_be_ignored(file_path: str) -> bool:
        """
        Returns True if the file should be ignored. E.g. if it's auto-generated.
        Only the path gets inspected, the file is not opened.

        :param file_path: The path to the file.
        :returns: True if the file should be ignored, False otherwise.
        """
        return "lib/rucio/vcsversion.py" in file_path

    @staticmethod
    def _get_file_shebag(file_path: str) -> Optional[str]:
        """
        Returns the python file shebag.

        :param file_path: The path to the file.
        :returns: The first line of the file if it starts with
                  `#!/usr/bin/env python`, None otherwise. The line includes its
                  new line character if the file has one.
        :throws UnicodeDecodeError: If the file content could not be decoded.
        :throws OSError: If the file could not be opened.
        """
        # The whole file gets decoded here, not only the first line:
        # `is_file_suitable` uses the UnicodeDecodeError to reject files which
        # can not be read as text.
        with open(file_path, "r") as file_obj:
            file_content = file_obj.readlines()

        if not file_content:
            return None

        first_line = file_content[0]
        # A shebag only counts at the very beginning of the first line
        if first_line.startswith("#!/usr/bin/env python"):
            return first_line
        return None

    @staticmethod
    def is_file_suitable(file_path: str) -> bool:
        """
        Accepts files which end with `.py` or `.py.mako`, or start with a
        python shebag.

        :param file_path: The path to the file.
        :returns: True if the template can be used for the file, False otherwise.
        """
        if PythonHeaderTemplate._should_file_be_ignored(file_path):
            return False

        try:
            is_file_using_shebag = PythonHeaderTemplate._get_file_shebag(file_path) is not None
        except (UnicodeDecodeError, OSError):
            # The path can not be read as a text file (binary content, a
            # directory, a missing file, ...), this makes it automatically invalid
            return False

        return file_path.endswith((".py", ".py.mako")) or is_file_using_shebag

    def get_header(self) -> str:
        """
        Returns the expected header. If the file starts with a shebag line,
        this line is kept in front of the header.

        :returns: The header for the file.
        """
        with open(self.file_path, "r") as file_obj:
            first_line = file_obj.readline()

        # Every shebag gets preserved, not only the `env python` ones:
        # `current_header_number_of_lines` counts the shebag line as part of
        # the header, so it would get lost if the header gets replaced.
        shebag_string = first_line if first_line.startswith("#!") else ""
        if shebag_string and not shebag_string.endswith("\n"):
            # The file consists only of the shebag without a trailing new line
            shebag_string += "\n"

        return shebag_string \
            + """# Copyright European Organization for Nuclear Research (CERN) since 2012
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

    def current_header_number_of_lines(self) -> int:
        """
        Counts the leading lines of the file which start with `#`. This
        includes a shebag line.

        :returns: Number of lines of the current header in the file.
        """
        ret = 0
        with open(self.file_path, "r") as file_obj:
            for line in file_obj:
                if not line.startswith("#"):
                    break
                ret += 1
        return ret


def header_template_factory(file_path: str) -> HeaderTemplate:
    """
    Creates the template which is responsible for the file. The registered
    templates are asked in registration order, the first one accepting the
    file wins.

    :param file_path: Path to the file for which the template gets created.
    :returns: The HeaderTemplate for the file.
    :throws HeaderTemplateNotFoundException: If no HeaderTemplate matches the file.
    """
    candidates = (
        template_class
        for template_class in HeaderTemplate.subtemplates
        if template_class.is_file_suitable(file_path)
    )
    chosen = next(candidates, None)
    if chosen is None:
        raise HeaderTemplateNotFoundException(
            f"File template not found! Is a template implemented for {file_path}?"
        )
    return chosen(file_path)


def template_exists_for_file(file_path: str) -> bool:
    """
    Checks whether one of the templates accepts the file.

    :param file_path: Path to the file.
    :returns: True if a HeaderTemplate exists for the file, False otherwise.
    """
    template_found = True
    try:
        header_template_factory(file_path)
    except HeaderTemplateNotFoundException:
        template_found = False
    return template_found


def get_all_files_recursive(path: str) -> Iterator[str]:
    """
    Returns all files in the project, known to the versioning system. This
    includes the files in subdirectories.

    :param path: The path which should contain the files.
    :returns: An Iterator over all files found in the VS, as printed by
              `git ls-files` executed in `path`.
    """
    # Imported locally, as the module is only needed in this function
    import subprocess

    try:
        # git is started directly with an argument list and `cwd`, so `path`
        # never gets interpreted by a shell. `-z` makes git print the paths
        # verbatim and NUL separated, instead of quoting unusual file names.
        completed = subprocess.run(
            ["git", "ls-files", "-z"],
            cwd=path,
            stdout=subprocess.PIPE,
            check=False
        )
        raw_output = completed.stdout
    except OSError:
        # git could not be started or `path` is not an accessible directory
        raw_output = b""

    tracked_files = [os.fsdecode(name) for name in raw_output.split(b"\0") if name]

    if not tracked_files:
        logging.warning("This does not seem to be a git project or no files found. We won't traverse the folder here.")
        return

    yield from tracked_files


def get_all_suitable_files_recursive(path: str) -> Iterator[str]:
    """
    Returns the files known to the VS for which a template exists. This
    includes the files in subdirectories.

    :param path: The path which should contain the files.
    :returns: An Iterator over all suitable files found.
    """
    yield from filter(template_exists_for_file, get_all_files_recursive(path))


def modify_header(header_template: HeaderTemplate) -> None:
    """
    Replaces the header of the file belonging to the template: the lines of
    the old header get dropped and the new header gets put in front of the
    remaining content.

    :param header_template: The HeaderTemplate which generates the new header
                            and determines the old one for the file.
    """
    target_path = header_template.file_path

    with open(target_path, "r") as source:
        original_lines = source.readlines()

    # Everything behind the old header is kept untouched
    old_header_length = header_template.current_header_number_of_lines()
    remaining_content = "".join(original_lines[old_header_length:])
    new_content = header_template.get_header() + remaining_content

    with open(target_path, "w") as sink:
        sink.write(new_content)


def process_file(file_path: str, dry_run: bool) -> ResultStates:
    """
    Checks the header of one file and rewrites it if necessary.

    :param file_path: Path to the file.
    :param dry_run: If True, the header won't be changed.
    :returns: A ResultState corresponding to the action performed on or the
              state of the file.
    """
    if not os.path.exists(file_path):
        return ResultStates.FileDoesNotExist

    try:
        template = header_template_factory(file_path)
    except HeaderTemplateNotFoundException:
        return ResultStates.FileTemplateNotFound

    if template.is_header_correct():
        outcome = ResultStates.HeaderIsCorrect
    elif dry_run:
        # Only report the problem, the file stays untouched
        outcome = ResultStates.HeaderIsIncorrect
    else:
        modify_header(template)
        outcome = ResultStates.HeaderModified
    return outcome


def process_files(files: list[str], dry_run: bool, disable_progress_bar: bool) -> dict[str, ResultStates]:
    """
    Processes all files, optionally showing a progress bar.

    :param files: The files to process.
    :param dry_run: If True, the headers won't be changed.
    :param disable_progress_bar: If True, disables the progress bar.
    :returns: A dictionary mapping the files to the corresponding ResultStates
              they are in.
    """
    # The progress bar is only available if tqdm could be imported
    show_progress_bar = 'tqdm' in sys.modules and not disable_progress_bar
    work_items = tqdm(files) if show_progress_bar else files

    return {current_path: process_file(current_path, dry_run) for current_path in work_items}


def main(arguments: Arguments):
    """
    Runs the header check for the provided files, or for all suitable files of
    the current directory if no files are provided.

    :param arguments: The parsed command line arguments.
    :returns: The exit code of the program.
    """
    files = arguments.files or list(get_all_suitable_files_recursive("."))
    if not files:
        logging.info("No files provided and no suitable files found in the current directory. Nothing to do.")
        return 0

    outcome_by_file = process_files(files, arguments.dry_run, arguments.disable_progress_bar)
    print_result_states(outcome_by_file)

    return exit_code_from_result_states(outcome_by_file.values())


if __name__ == '__main__':
    # The log level can be overridden with the LOGLEVEL environment variable
    logging.basicConfig(format="%(message)s", level=os.environ.get("LOGLEVEL", "INFO"))

    parser = argparse.ArgumentParser(description="""Program to sanitize the licence headers in the repository.

If a file header is incorrect or modified, it has a non-zero return code.""",
                                     epilog="""examples:
  add_header                  Runs the script on every suitable file. This could take some time.
  add_header --dry-run        Runs the script without modifying anything. This could take some time.
  add_header FILE [FILES]     Runs the script against the specified files.""",
                                     formatter_class=argparse.RawDescriptionHelpFormatter)
    parser.add_argument(dest='files', action='store', nargs='*',
                        help='The files to check. Leave empty to check all files in the repository.')
    parser.add_argument('-d', '--dry-run', action='store_true',
                        help='Dry run mode, does not change anything')
    parser.add_argument('--disable-progress-bar', action='store_true',
                        help='Disables the progress bar.')

    # The values get parsed into an instance. Passing the class itself as
    # namespace would store them as class attributes, shared by all instances.
    arguments = parser.parse_args(namespace=Arguments())
    sys.exit(main(arguments))
